# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import xlrd

# Excel workbook holding the fire/theft samples; it is a relative path, so it is
# resolved against the current working directory. Read in Step 1 below.
DATA_FILE = "fire_theft.xls"


def huber_loss(labels, predictions, delta=0.1):
    """Return the element-wise Huber loss between labels and predictions.

    With r = |labels - predictions|, each element is
        0.5 * r**2                    if r < delta
        delta * r - 0.5 * delta**2    otherwise
    i.e. quadratic for small residuals and only linear for large ones.
    No reduction (sum/mean) is applied.

    Args:
        labels: target values (a tensor, or a value tf can convert to one).
        predictions: model outputs, broadcast-compatible with ``labels``.
        delta: residual magnitude at which the loss switches from the
            quadratic branch to the linear branch.

    Returns:
        Tensor of per-element losses with the shape of the residual.
    """
    residual = tf.abs(labels - predictions)
    quadratic = 0.5 * tf.square(residual)
    linear = delta * residual - 0.5 * tf.square(delta)
    # Both branches are computed for every element; tf.where keeps the
    # quadratic one wherever the residual is strictly below delta.
    return tf.where(tf.less(residual, delta), quadratic, linear)

# Step 1: load the samples from the first worksheet of the .xls workbook.
book = xlrd.open_workbook(DATA_FILE, encoding_override="utf-8")
sheet = book.sheet_by_index(0)
# Row 0 is skipped (presumably a header row); every later row is one sample.
n_samples = sheet.nrows - 1
data = np.asarray([sheet.row_values(row_idx)
                   for row_idx in range(1, 1 + n_samples)])

# Step 2: placeholders for the input X (number of fires) and the label Y
# (number of thefts); the training loop below feeds one sample at a time.
X = tf.placeholder(tf.float32, name="X")
Y = tf.placeholder(tf.float32, name="Y")

# Step 3: model parameters, all initialized to 0:
# w multiplies X**2, u multiplies X, and b is the bias.
w = tf.Variable(0.0, name="weights_1")
u = tf.Variable(0.0, name="weights_2")
b = tf.Variable(0.0, name="bias")

# Step 4: quadratic model predicting Y (number of thefts) from X (number of fires).
Y_predicted = X * X * w + X * u + b

# Step 5: Huber loss with a small delta, used here instead of the squared
# error tf.square(Y - Y_predicted); any other per-sample loss could be swapped in.
HUBER_DELTA = 0.01
huberloss = huber_loss(Y, Y_predicted, HUBER_DELTA)

# Step 6: plain gradient descent. minimize() builds the gradient ops for the
# trainable variables the loss depends on (w, u, b) and returns a single
# Operation that applies one update step each time it is run; it does not
# return the updated variables themselves.
LEARNING_RATE = 0.00001
optimizer = tf.train.GradientDescentOptimizer(
    learning_rate=LEARNING_RATE).minimize(huberloss)
# 自定生成一个gradients tensor和一个梯度下降优化器的op，该op的输入是weights,bias和gradients.

N_EPOCHS = 5000  # full passes over all samples

with tf.Session() as sess:
    # Step 7: initialize all variables (w, u and b).
    sess.run(tf.global_variables_initializer())
    # Write the graph for TensorBoard. Nothing else is logged through this
    # writer, so close it immediately: close() flushes pending events to disk.
    writer = tf.summary.FileWriter('./my_graph/03/liner_reg', sess.graph)
    writer.close()
    # Step 8: train the model, one optimizer step per sample.
    for epoch in range(N_EPOCHS):
        total_loss = 0
        for x, y in data:
            # Running `optimizer` applies one gradient-descent update; the
            # loss for this sample is fetched in the same run.
            _, sample_loss = sess.run([optimizer, huberloss],
                                      feed_dict={X: x, Y: y})
            total_loss += sample_loss
        print(total_loss)
    # Step 9: read back the trained values of w, u and b.
    w_value, u_value, b_value = sess.run([w, u, b])

# Split the (n_samples, 2) array into its two columns (.T is the transpose).
# Distinct names keep the X/Y graph placeholders from being clobbered.
x_data, y_data = data.T[0], data.T[1]
plt.plot(x_data, y_data, 'bo', label='Real data')
# Nothing guarantees the rows are sorted by x, and a line drawn through unsorted
# x values of a quadratic zig-zags back and forth. Evaluate the fitted curve on
# an increasing, evenly spaced grid spanning the data instead.
x_curve = np.linspace(x_data.min(), x_data.max(), 200)
plt.plot(x_curve, x_curve * x_curve * w_value + x_curve * u_value + b_value,
         'r', label='Predicted data')
plt.legend()
plt.show()
